{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "934b74e1",
   "metadata": {},
   "source": [
    "# NAACL'21 DLG4NLP Tutorial Demo: Semantic Parsing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "767c3b2e",
   "metadata": {},
   "source": [
    "In this tutorial demo, we will use the Graph4NLP library to build a GNN-based semantic parsing model. The model consists of \n",
    "- graph construction module (e.g., node embedding based dynamic graph)\n",
    "- graph embedding module (e.g., Bi-Sep GAT)\n",
    "- predictoin module (e.g., RNN decoder with attention, copy and coverage mechanisms)\n",
    "\n",
    "We will use the built-in Graph2Seq model APIs to build the model, and evaluate it on the Jobs dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac004ea7",
   "metadata": {},
   "source": [
    "### Environment setup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1de88301",
   "metadata": {},
   "source": [
    "1. Create virtual environment\n",
    "```\n",
    "conda create --name graph4nlp python=3.7\n",
    "conda activate graph4nlp\n",
    "```\n",
    "\n",
    "2. Install [graph4nlp](https://github.com/graph4ai/graph4nlp) library\n",
    "- Clone the github repo\n",
    "```\n",
    "git clone -b stable https://github.com/graph4ai/graph4nlp.git\n",
    "cd graph4nlp\n",
    "```\n",
    "- Then run `./configure` (or `./configure.bat` if you are using Windows 10) to config your installation. The configuration program will ask you to specify your CUDA version. If you do not have a GPU, please choose 'cpu'.\n",
    "```\n",
    "./configure\n",
    "```\n",
    "- Finally, install the package\n",
    "```\n",
    "python setup.py install\n",
    "```\n",
    "\n",
    "3. Set up StanfordCoreNLP (for static graph construction only, unnecessary for this demo because preprocessed data is provided)\n",
    "- Download [StanfordCoreNLP](https://stanfordnlp.github.io/CoreNLP/)\n",
    "- Go to the root folder and start the server\n",
    "```\n",
    "java -mx4g -cp \"*\" edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9000 -timeout 15000\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "086ca306",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using backend: pytorch\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import yaml\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from graph4nlp.pytorch.datasets.jobs import JobsDataset\n",
    "from graph4nlp.pytorch.models.graph2seq import Graph2Seq\n",
    "from graph4nlp.pytorch.models.graph2seq_loss import Graph2SeqLoss\n",
    "from graph4nlp.pytorch.modules.graph_construction import *\n",
    "from graph4nlp.pytorch.modules.evaluation.base import EvaluationMetricBase\n",
    "from graph4nlp.pytorch.modules.utils.config_utils import update_values\n",
    "from graph4nlp.pytorch.modules.utils.copy_utils import prepare_ext_vocab\n",
    "from graph4nlp.pytorch.modules.config import get_basic_args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b1637d23",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Jobs:\n",
    "    def __init__(self, opt):\n",
    "        super(Jobs, self).__init__()\n",
    "        self.opt = opt\n",
    "        self.use_copy = self.opt[\"decoder_args\"][\"rnn_decoder_share\"][\"use_copy\"]\n",
    "        self.use_coverage = self.opt[\"decoder_args\"][\"rnn_decoder_share\"][\"use_coverage\"]\n",
    "        self._build_device(self.opt)\n",
    "        self._build_dataloader()\n",
    "        self._build_model()\n",
    "        self._build_loss_function()\n",
    "        self._build_optimizer()\n",
    "        self._build_evaluation()\n",
    "\n",
    "    def _build_device(self, opt):\n",
    "        seed = opt[\"seed\"]\n",
    "        np.random.seed(seed)\n",
    "        if opt[\"use_gpu\"] != 0 and torch.cuda.is_available():\n",
    "            print('[ Using CUDA ]')\n",
    "            torch.manual_seed(seed)\n",
    "            torch.cuda.manual_seed_all(seed)\n",
    "            from torch.backends import cudnn\n",
    "            cudnn.benchmark = True\n",
    "            device = torch.device('cuda' if opt[\"gpu\"] < 0 else 'cuda:%d' % opt[\"gpu\"])\n",
    "        else:\n",
    "            print('[ Using CPU ]')\n",
    "            device = torch.device('cpu')\n",
    "        self.device = device\n",
    "\n",
    "    def _build_dataloader(self):\n",
    "        if self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"graph_type\"] == \"dependency\":\n",
    "            topology_builder = DependencyBasedGraphConstruction\n",
    "            graph_type = 'static'\n",
    "            dynamic_init_topology_builder = None\n",
    "        elif self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"graph_type\"] == \"constituency\":\n",
    "            topology_builder = ConstituencyBasedGraphConstruction\n",
    "            graph_type = 'static'\n",
    "            dynamic_init_topology_builder = None\n",
    "        elif self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"graph_type\"] == \"node_emb\":\n",
    "            topology_builder = NodeEmbeddingBasedGraphConstruction\n",
    "            graph_type = 'dynamic'\n",
    "            dynamic_init_topology_builder = None\n",
    "        elif self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"graph_type\"] == \"node_emb_refined\":\n",
    "            topology_builder = NodeEmbeddingBasedRefinedGraphConstruction\n",
    "            graph_type = 'dynamic'\n",
    "            dynamic_init_graph_type = self.opt[\"graph_construction_args\"][\"graph_construction_private\"][\n",
    "                \"dynamic_init_graph_type\"]\n",
    "            if dynamic_init_graph_type is None or dynamic_init_graph_type == 'line':\n",
    "                dynamic_init_topology_builder = None\n",
    "            elif dynamic_init_graph_type == 'dependency':\n",
    "                dynamic_init_topology_builder = DependencyBasedGraphConstruction\n",
    "            elif dynamic_init_graph_type == 'constituency':\n",
    "                dynamic_init_topology_builder = ConstituencyBasedGraphConstruction\n",
    "            else:\n",
    "                raise RuntimeError('Define your own dynamic_init_topology_builder')\n",
    "        else:\n",
    "            raise NotImplementedError(\"Define your topology builder.\")\n",
    "\n",
    "        dataset = JobsDataset(root_dir=self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\"root_dir\"],\n",
    "                              pretrained_word_emb_name=self.opt[\"pretrained_word_emb_name\"],\n",
    "                              pretrained_word_emb_url=self.opt[\"pretrained_word_emb_url\"],\n",
    "                              pretrained_word_emb_cache_dir=self.opt[\"pretrained_word_emb_cache_dir\"],\n",
    "                              merge_strategy=self.opt[\"graph_construction_args\"][\"graph_construction_private\"][\n",
    "                                  \"merge_strategy\"],\n",
    "                              edge_strategy=self.opt[\"graph_construction_args\"][\"graph_construction_private\"][\n",
    "                                  \"edge_strategy\"],\n",
    "                              seed=self.opt[\"seed\"],\n",
    "                              word_emb_size=self.opt[\"word_emb_size\"],\n",
    "                              share_vocab=self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\n",
    "                                  \"share_vocab\"],\n",
    "                              graph_type=graph_type,\n",
    "                              topology_builder=topology_builder,\n",
    "                              topology_subdir=self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\n",
    "                                  \"topology_subdir\"],\n",
    "                              thread_number=self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\n",
    "                                  \"thread_number\"],\n",
    "                              dynamic_graph_type=self.opt[\"graph_construction_args\"][\"graph_construction_share\"][\n",
    "                                  \"graph_type\"],\n",
    "                              dynamic_init_topology_builder=dynamic_init_topology_builder,\n",
    "                              dynamic_init_topology_aux_args=None)\n",
    "\n",
    "        self.train_dataloader = DataLoader(dataset.train, batch_size=self.opt[\"batch_size\"], shuffle=True,\n",
    "                                           num_workers=1,\n",
    "                                           collate_fn=dataset.collate_fn)\n",
    "        self.test_dataloader = DataLoader(dataset.test, batch_size=self.opt[\"batch_size\"], shuffle=False, num_workers=1,\n",
    "                                          collate_fn=dataset.collate_fn)\n",
    "        self.vocab = dataset.vocab_model\n",
    "\n",
    "    def _build_model(self):\n",
    "        self.model = Graph2Seq.from_args(self.opt, self.vocab).to(self.device)\n",
    "\n",
    "    def _build_loss_function(self):\n",
    "        self.loss = Graph2SeqLoss(ignore_index=self.vocab.out_word_vocab.PAD,\n",
    "                                  use_coverage=self.use_coverage, coverage_weight=0.3)\n",
    "    def _build_optimizer(self):\n",
    "        parameters = [p for p in self.model.parameters() if p.requires_grad]\n",
    "        self.optimizer = optim.Adam(parameters, lr=self.opt[\"learning_rate\"])\n",
    "\n",
    "    def _build_evaluation(self):\n",
    "        self.metrics = [ExpressionAccuracy()]\n",
    "\n",
    "    def train(self):\n",
    "        max_score = -1\n",
    "        self._best_epoch = -1\n",
    "        for epoch in range(200):\n",
    "            self.model.train()\n",
    "            self.train_epoch(epoch, split=\"train\")\n",
    "            self._adjust_lr(epoch)\n",
    "            if epoch >= 0:\n",
    "                score = self.evaluate(split=\"test\")\n",
    "                if score >= max_score:\n",
    "                    print(\"Best model saved, epoch {}\".format(epoch))\n",
    "                    self.save_checkpoint(\"best.pth\")\n",
    "                    self._best_epoch = epoch\n",
    "                max_score = max(max_score, score)\n",
    "            if epoch >= 30 and self._stop_condition(epoch):\n",
    "                break\n",
    "        return max_score\n",
    "\n",
    "    def _stop_condition(self, epoch, patience=20):\n",
    "        return epoch > patience + self._best_epoch\n",
    "\n",
    "    def _adjust_lr(self, epoch):\n",
    "        def set_lr(optimizer, decay_factor):\n",
    "            for group in optimizer.param_groups:\n",
    "                group['lr'] = group['lr'] * decay_factor\n",
    "\n",
    "        epoch_diff = epoch - self.opt[\"lr_start_decay_epoch\"]\n",
    "        if epoch_diff >= 0 and epoch_diff % self.opt[\"lr_decay_per_epoch\"] == 0:\n",
    "            if self.opt[\"learning_rate\"] > self.opt[\"min_lr\"]:\n",
    "                set_lr(self.optimizer, self.opt[\"lr_decay_rate\"])\n",
    "                self.opt[\"learning_rate\"] = self.opt[\"learning_rate\"] * self.opt[\"lr_decay_rate\"]\n",
    "                print(\"Learning rate adjusted: {:.5f}\".format(self.opt[\"learning_rate\"]))\n",
    "\n",
    "    def train_epoch(self, epoch, split=\"train\"):\n",
    "        assert split in [\"train\"]\n",
    "        print(\"Start training in split {}, Epoch: {}\".format(split, epoch))\n",
    "        loss_collect = []\n",
    "        dataloader = self.train_dataloader\n",
    "        step_all_train = len(dataloader)\n",
    "        for step, data in enumerate(dataloader):\n",
    "            graph, tgt, gt_str = data[\"graph_data\"], data[\"tgt_seq\"], data[\"output_str\"]\n",
    "            graph = graph.to(self.device)\n",
    "            tgt = tgt.to(self.device)\n",
    "            oov_dict = None\n",
    "            if self.use_copy:\n",
    "                oov_dict, tgt = prepare_ext_vocab(graph, self.vocab, gt_str=gt_str, device=self.device)\n",
    "\n",
    "            prob, enc_attn_weights, coverage_vectors = self.model(graph, tgt, oov_dict=oov_dict)\n",
    "            loss = self.loss(logits=prob, label=tgt, enc_attn_weights=enc_attn_weights,\n",
    "                             coverage_vectors=coverage_vectors)\n",
    "            loss_collect.append(loss.item())\n",
    "            if step % self.opt[\"loss_display_step\"] == 0 and step != 0:\n",
    "                print(\"Epoch {}: [{} / {}] loss: {:.3f}\".format(epoch, step, step_all_train,\n",
    "                                                                           np.mean(loss_collect)))\n",
    "                loss_collect = []\n",
    "            self.optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            self.optimizer.step()\n",
    "\n",
    "    def evaluate(self, split=\"val\"):\n",
    "        self.model.eval()\n",
    "        pred_collect = []\n",
    "        gt_collect = []\n",
    "        assert split in [\"val\", \"test\"]\n",
    "        dataloader = self.val_dataloader if split == \"val\" else self.test_dataloader\n",
    "        for data in dataloader:\n",
    "            graph, tgt, gt_str = data[\"graph_data\"], data[\"tgt_seq\"], data[\"output_str\"]\n",
    "            graph = graph.to(self.device)\n",
    "            if self.use_copy:\n",
    "                oov_dict = prepare_ext_vocab(batch_graph=graph, vocab=self.vocab, device=self.device)\n",
    "                ref_dict = oov_dict\n",
    "            else:\n",
    "                oov_dict = None\n",
    "                ref_dict = self.vocab.out_word_vocab\n",
    "\n",
    "            prob, _, _ = self.model(graph, oov_dict=oov_dict)\n",
    "            pred = prob.argmax(dim=-1)\n",
    "\n",
    "            pred_str = wordid2str(pred.detach().cpu(), ref_dict)\n",
    "            pred_collect.extend(pred_str)\n",
    "            gt_collect.extend(gt_str)\n",
    "\n",
    "        score = self.metrics[0].calculate_scores(ground_truth=gt_collect, predict=pred_collect)\n",
    "        print(\"Evaluation accuracy in `{}` split: {:.3f}\".format(split, score))\n",
    "        return score\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def translate(self):\n",
    "        self.model.eval()\n",
    "\n",
    "        pred_collect = []\n",
    "        gt_collect = []\n",
    "        dataloader = self.test_dataloader\n",
    "        for data in dataloader:\n",
    "            graph, tgt, gt_str = data[\"graph_data\"], data[\"tgt_seq\"], data[\"output_str\"]\n",
    "            graph = graph.to(self.device)\n",
    "            if self.use_copy:\n",
    "                oov_dict = prepare_ext_vocab(batch_graph=graph, vocab=self.vocab, device=self.device)\n",
    "                ref_dict = oov_dict\n",
    "            else:\n",
    "                oov_dict = None\n",
    "                ref_dict = self.vocab.out_word_vocab\n",
    "\n",
    "            pred = self.model.translate(batch_graph=graph, oov_dict=oov_dict, beam_size=4, topk=1)\n",
    "\n",
    "            pred_ids = pred[:, 0, :]  # we just use the top-1\n",
    "\n",
    "            pred_str = wordid2str(pred_ids.detach().cpu(), ref_dict)\n",
    "\n",
    "            pred_collect.extend(pred_str)\n",
    "            gt_collect.extend(gt_str)\n",
    "\n",
    "        score = self.metrics[0].calculate_scores(ground_truth=gt_collect, predict=pred_collect)\n",
    "        return score\n",
    "\n",
    "    def load_checkpoint(self, checkpoint_name):\n",
    "        checkpoint_path = os.path.join(self.opt[\"checkpoint_save_path\"], checkpoint_name)\n",
    "        self.model.load_state_dict(torch.load(checkpoint_path))\n",
    "\n",
    "    def save_checkpoint(self, checkpoint_name):\n",
    "        checkpoint_path = os.path.join(self.opt[\"checkpoint_save_path\"], checkpoint_name)\n",
    "        if not os.path.exists(self.opt[\"checkpoint_save_path\"]):\n",
    "            os.makedirs(self.opt[\"checkpoint_save_path\"], exist_ok=True)\n",
    "        torch.save(self.model.state_dict(), checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dac6e553",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ExpressionAccuracy(EvaluationMetricBase):\n",
    "    def __init__(self):\n",
    "        super(ExpressionAccuracy, self).__init__()\n",
    "\n",
    "    def calculate_scores(self, ground_truth, predict):\n",
    "        correct = 0\n",
    "        assert len(ground_truth) == len(predict)\n",
    "        for gt, pred in zip(ground_truth, predict):\n",
    "            if gt == pred:\n",
    "                correct += 1.\n",
    "        return correct / len(ground_truth)\n",
    "\n",
    "def wordid2str(word_ids, vocab):\n",
    "    ret = []\n",
    "    assert len(word_ids.shape) == 2, print(word_ids.shape)\n",
    "    for i in range(word_ids.shape[0]):\n",
    "        id_list = word_ids[i, :]\n",
    "        ret_inst = []\n",
    "        for j in range(id_list.shape[0]):\n",
    "            if id_list[j] == vocab.EOS:\n",
    "                break\n",
    "            token = vocab.getWord(id_list[j])\n",
    "            ret_inst.append(token)\n",
    "        ret.append(\" \".join(ret_inst))\n",
    "    return ret"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a702636a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# config setup\n",
    "config_file = '../config/jobs/gat_bi_sep_dynamic_node_emb.yaml'\n",
    "config = yaml.load(open(config_file, 'r'), Loader=yaml.FullLoader)\n",
    "\n",
    "opt = get_basic_args(graph_construction_name=config[\"graph_construction_name\"],\n",
    "                              graph_embedding_name=config[\"graph_embedding_name\"],\n",
    "                              decoder_name=config[\"decoder_name\"])\n",
    "update_values(to_args=opt, from_args_list=[config, config[\"other_args\"]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cb15c2f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ Using CPU ]\n",
      "Loading pre-built vocab model stored in ../data/jobs/processed/NodeEmbGraph/vocab.pt\n",
      "Start training in split train, Epoch: 0\n",
      "Epoch 0: [10 / 21] loss: 3.991\n",
      "Epoch 0: [20 / 21] loss: 2.538\n",
      "Evaluation accuracy in `test` split: 0.000\n",
      "Best model saved, epoch 0\n",
      "Start training in split train, Epoch: 1\n",
      "Epoch 1: [10 / 21] loss: 1.887\n",
      "Epoch 1: [20 / 21] loss: 1.402\n",
      "Evaluation accuracy in `test` split: 0.000\n",
      "Best model saved, epoch 1\n",
      "Start training in split train, Epoch: 2\n",
      "Epoch 2: [10 / 21] loss: 1.196\n"
     ]
    }
   ],
   "source": [
    "# run the model\n",
    "runner = Jobs(opt)\n",
    "max_score = runner.train()\n",
    "print(\"Train finish, best val score: {:.3f}\".format(max_score))\n",
    "runner.load_checkpoint(\"best.pth\")\n",
    "# runner.evaluate(\"test\")\n",
    "test_score = runner.translate()\n",
    "print('Final test accuracy: {}'.format(test_score))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb547d8c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
